{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Dynamic CNN for Modeling Sentences (Kalchbrenner et al., 2014)\n",
    "- Simple CNN structures for modeling sentences, such as the one used in Kim (2014), sometimes struggle to capture long-distance dependencies\n",
    "- To address this, Kalchbrenner et al. proposed the **Dynamic Convolutional Neural Network (DCNN)** in their 2014 paper (Kalchbrenner, N., Grefenstette, E., & Blunsom, P. (2014). A convolutional neural network for modelling sentences. arXiv preprint arXiv:1404.2188.)\n",
    "    - \"The network uses Dynamic k-Max Pooling, a global pooling operation over linear sequences. The network handles input sentences of varying length and induces a feature graph over the sentence that is capable of explicitly capturing short and long-range relations.\"\n",
    "    \n",
    "<br/>\n",
    "<img src=\"https://www.aclweb.org/anthology/P/P14/P14-1062/image003.png\" style=\"width: 400px\"/>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from keras.preprocessing import sequence\n",
    "from keras.models import *\n",
    "from keras.layers import *\n",
    "from keras.callbacks import *\n",
    "from keras.datasets import imdb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining a k-max pooling layer\n",
    "- Keras does not provide a k-max pooling layer, so it is defined here as a custom layer\n",
    "- source: https://github.com/keras-team/keras/issues/373"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras import backend as K\n",
    "from keras.engine import Layer, InputSpec\n",
    "import tensorflow as tf\n",
    "\n",
    "class KMaxPooling(Layer):\n",
    "    \"\"\"\n",
    "    K-max pooling layer that extracts the k highest activations along the last axis.\n",
    "    Expects input of shape (batch, channels, steps), so apply Permute((2, 1)) to a Conv1D output first.\n",
    "    TensorFlow backend.\n",
    "    \"\"\"\n",
    "    def __init__(self, k=1, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.input_spec = InputSpec(ndim=3)\n",
    "        self.k = k\n",
    "\n",
    "    def compute_output_shape(self, input_shape):\n",
    "        # flattened output: k values for each of the input_shape[1] channels\n",
    "        return (input_shape[0], (input_shape[1] * self.k))\n",
    "\n",
    "    def call(self, inputs):\n",
    "        # tf.nn.top_k works along the last dimension and returns (values, indices);\n",
    "        # with sorted=True the values come back in descending order, not in sequence order\n",
    "        top_k = tf.nn.top_k(inputs, k=self.k, sorted=True, name=None)[0]\n",
    "        \n",
    "        # flatten to (batch, channels * k) so the tensor matches compute_output_shape\n",
    "        channels = K.int_shape(inputs)[1]\n",
    "        return tf.reshape(top_k, (-1, channels * self.k))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Load Dataset\n",
    "- Load the ```imdb``` dataset provided by Keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "num_words = 5000\n",
    "max_len = 300\n",
    "embedding_dim = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(25000, 300) (25000, 300) (25000,) (25000,)\n"
     ]
    }
   ],
   "source": [
    "(X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=num_words)\n",
    "\n",
    "X_train = sequence.pad_sequences(X_train, maxlen=max_len)\n",
    "X_test = sequence.pad_sequences(X_test, maxlen=max_len)\n",
    "\n",
    "print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Basic Dynamic CNN\n",
    "- Basic DCNN structure with only one feature map and one convolution/pooling layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def basic_dynamic_cnn(k = 5):\n",
    "    model = Sequential()\n",
    "    # Embedding each word\n",
    "    model.add(Embedding(num_words, embedding_dim, input_length = max_len))\n",
    "    # Wide convolution\n",
    "    model.add(ZeroPadding1D(29))\n",
    "    model.add(Conv1D(embedding_dim, 30, activation = 'relu'))\n",
    "    # k-max pooling\n",
    "    model.add(Permute((2, 1)))\n",
    "    model.add(KMaxPooling(k))\n",
    "    model.add(Reshape((k, -1)))\n",
    "    model.add(Flatten())\n",
    "    model.add(Dense(1, activation = 'sigmoid'))\n",
    "    \n",
    "    model.compile(loss = 'binary_crossentropy', optimizer = 'adam', metrics = ['accuracy'])\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_13 (Embedding)     (None, 300, 50)           250000    \n",
      "_________________________________________________________________\n",
      "zero_padding1d_13 (ZeroPaddi (None, 358, 50)           0         \n",
      "_________________________________________________________________\n",
      "conv1d_12 (Conv1D)           (None, 329, 50)           75050     \n",
      "_________________________________________________________________\n",
      "permute_7 (Permute)          (None, 50, 329)           0         \n",
      "_________________________________________________________________\n",
      "k_max_pooling_9 (KMaxPooling (None, 250)               0         \n",
      "_________________________________________________________________\n",
      "reshape_7 (Reshape)          (None, 5, 50)             0         \n",
      "_________________________________________________________________\n",
      "flatten_5 (Flatten)          (None, 250)               0         \n",
      "_________________________________________________________________\n",
      "dense_5 (Dense)              (None, 1)                 251       \n",
      "=================================================================\n",
      "Total params: 325,301\n",
      "Trainable params: 325,301\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# use a separate name for the model so the builder function can be called again later\n",
    "basic_dynamic_cnn_model = basic_dynamic_cnn()\n",
    "basic_dynamic_cnn_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "callbacks = [ModelCheckpoint(filepath = 'best_model.hdf5', monitor='val_acc', verbose=1, save_best_only = True, mode='max')]\n",
    "history = basic_dynamic_cnn_model.fit(X_train, y_train, callbacks = callbacks, epochs = 10, validation_split = 0.2, batch_size = 200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "24832/25000 [============================>.] - ETA: 0sTest accuracy:  0.87084\n"
     ]
    }
   ],
   "source": [
    "basic_dynamic_cnn_best_model = basic_dynamic_cnn()\n",
    "basic_dynamic_cnn_best_model.load_weights('best_model.hdf5')\n",
    "basic_dynamic_cnn_best_model.compile(optimizer = 'Adam', loss = 'binary_crossentropy', metrics = ['accuracy'])\n",
    "results = basic_dynamic_cnn_best_model.evaluate(X_test, y_test)\n",
    "print('Test accuracy: ', results[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. DCNN with two convolution (pooling) layers\n",
    "- Perform wide convolution and k-max pooling twice"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "# separate k and kernel size for each convolution/pooling stage\n",
    "def two_conv_dynamic_cnn(k1 = 20, k2 = 10, ksize1 = 20, ksize2 = 30):\n",
    "    inputs = Input(shape = (X_train.shape[-1],))\n",
    "    embed = Embedding(num_words, embedding_dim, input_length = max_len)(inputs)\n",
    "    padded = ZeroPadding1D(ksize1 - 1)(embed)\n",
    "    conv1 = Conv1D(embedding_dim, ksize1, activation = 'relu')(padded)\n",
    "    permuted = Permute((2,1))(conv1)\n",
    "    kmaxpool1 = KMaxPooling(k1)(permuted)\n",
    "    kmaxpool1 = Reshape((k1, -1))(kmaxpool1)\n",
    "    padded = ZeroPadding1D(ksize2 -1)(kmaxpool1)\n",
    "    conv2 = Conv1D(embedding_dim, ksize2, activation = 'relu')(padded)\n",
    "    permuted = Permute((2,1))(conv2)\n",
    "    kmaxpool2 = KMaxPooling(k2)(permuted)\n",
    "    kmaxpool2 = Reshape((k2, -1))(kmaxpool2)\n",
    "    flattened = Flatten()(kmaxpool2)\n",
    "    outputs = Dense(1, activation = 'sigmoid')(flattened)\n",
    "    \n",
    "    model = Model(inputs = inputs, outputs = outputs)\n",
    "    model.compile(loss = 'binary_crossentropy', optimizer = 'adam', metrics = ['accuracy'])\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_13 (InputLayer)        (None, 300)               0         \n",
      "_________________________________________________________________\n",
      "embedding_28 (Embedding)     (None, 300, 50)           250000    \n",
      "_________________________________________________________________\n",
      "zero_padding1d_30 (ZeroPaddi (None, 338, 50)           0         \n",
      "_________________________________________________________________\n",
      "conv1d_29 (Conv1D)           (None, 319, 50)           50050     \n",
      "_________________________________________________________________\n",
      "permute_21 (Permute)         (None, 50, 319)           0         \n",
      "_________________________________________________________________\n",
      "k_max_pooling_23 (KMaxPoolin (None, 1000)              0         \n",
      "_________________________________________________________________\n",
      "reshape_21 (Reshape)         (None, 20, 50)            0         \n",
      "_________________________________________________________________\n",
      "zero_padding1d_31 (ZeroPaddi (None, 78, 50)            0         \n",
      "_________________________________________________________________\n",
      "conv1d_30 (Conv1D)           (None, 49, 50)            75050     \n",
      "_________________________________________________________________\n",
      "permute_22 (Permute)         (None, 50, 49)            0         \n",
      "_________________________________________________________________\n",
      "k_max_pooling_24 (KMaxPoolin (None, 500)               0         \n",
      "_________________________________________________________________\n",
      "reshape_22 (Reshape)         (None, 10, 50)            0         \n",
      "_________________________________________________________________\n",
      "flatten_11 (Flatten)         (None, 500)               0         \n",
      "_________________________________________________________________\n",
      "dense_11 (Dense)             (None, 1)                 501       \n",
      "=================================================================\n",
      "Total params: 375,601\n",
      "Trainable params: 375,601\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# use a separate name for the model so the builder function can be called again later\n",
    "two_conv_dynamic_cnn_model = two_conv_dynamic_cnn()\n",
    "two_conv_dynamic_cnn_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "callbacks = [ModelCheckpoint(filepath = 'best_model.hdf5', monitor='val_acc', verbose=1, save_best_only = True, mode='max')]\n",
    "history = two_conv_dynamic_cnn_model.fit(X_train, y_train, callbacks = callbacks, epochs = 10, validation_split = 0.2, batch_size = 200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "24960/25000 [============================>.] - ETA: 0sTest accuracy:  0.8832\n"
     ]
    }
   ],
   "source": [
    "two_conv_dynamic_cnn_best_model = two_conv_dynamic_cnn()\n",
    "two_conv_dynamic_cnn_best_model.load_weights('best_model.hdf5')\n",
    "two_conv_dynamic_cnn_best_model.compile(optimizer = 'Adam', loss = 'binary_crossentropy', metrics = ['accuracy'])\n",
    "results = two_conv_dynamic_cnn_best_model.evaluate(X_test, y_test)\n",
    "print('Test accuracy: ', results[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. DCNN with two feature maps\n",
    "- DCNN with two feature maps (each with two convolutions), concatenated at the end of the model\n",
    "- Of the three variants here, this implementation is the closest to the original model described in the paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def two_feature_map_dynamic_cnn(k1 = 20, k2 = 10, ksize1 = 20, ksize2 = 30):\n",
    "    inputs = Input(shape = (X_train.shape[-1],))\n",
    "    embed = Embedding(num_words, embedding_dim, input_length = max_len)(inputs)\n",
    "    conv_results = []\n",
    "    # two feature maps using for loop\n",
    "    for i in range(2):\n",
    "        padded = ZeroPadding1D(ksize1 - 1)(embed)\n",
    "        conv1 = Conv1D(embedding_dim, ksize1, activation = 'relu')(padded)\n",
    "        permuted = Permute((2,1))(conv1)\n",
    "        kmaxpool1 = KMaxPooling(k1)(permuted)\n",
    "        kmaxpool1 = Reshape((k1, -1))(kmaxpool1)\n",
    "        padded = ZeroPadding1D(ksize2 -1)(kmaxpool1)\n",
    "        conv2 = Conv1D(embedding_dim, ksize2, activation = 'relu')(padded)\n",
    "        permuted = Permute((2,1))(conv2)\n",
    "        kmaxpool2 = KMaxPooling(k2)(permuted)\n",
    "        kmaxpool2 = Reshape((k2, -1))(kmaxpool2)\n",
    "        flattened = Flatten()(kmaxpool2)\n",
    "        conv_results.append(flattened)\n",
    "    conv_result = concatenate(conv_results)\n",
    "    outputs = Dense(1, activation = 'sigmoid')(conv_result)\n",
    "    \n",
    "    model = Model(inputs = inputs, outputs = outputs)\n",
    "    model.compile(loss = 'binary_crossentropy', optimizer = 'adam', metrics = ['accuracy'])\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_15 (InputLayer)           (None, 300)          0                                            \n",
      "__________________________________________________________________________________________________\n",
      "embedding_30 (Embedding)        (None, 300, 50)      250000      input_15[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "zero_padding1d_34 (ZeroPadding1 (None, 338, 50)      0           embedding_30[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "zero_padding1d_36 (ZeroPadding1 (None, 338, 50)      0           embedding_30[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "conv1d_33 (Conv1D)              (None, 319, 50)      50050       zero_padding1d_34[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv1d_35 (Conv1D)              (None, 319, 50)      50050       zero_padding1d_36[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "permute_25 (Permute)            (None, 50, 319)      0           conv1d_33[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "permute_27 (Permute)            (None, 50, 319)      0           conv1d_35[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "k_max_pooling_27 (KMaxPooling)  (None, 1000)         0           permute_25[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "k_max_pooling_29 (KMaxPooling)  (None, 1000)         0           permute_27[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "reshape_25 (Reshape)            (None, 20, 50)       0           k_max_pooling_27[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "reshape_27 (Reshape)            (None, 20, 50)       0           k_max_pooling_29[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "zero_padding1d_35 (ZeroPadding1 (None, 78, 50)       0           reshape_25[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "zero_padding1d_37 (ZeroPadding1 (None, 78, 50)       0           reshape_27[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "conv1d_34 (Conv1D)              (None, 49, 50)       75050       zero_padding1d_35[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv1d_36 (Conv1D)              (None, 49, 50)       75050       zero_padding1d_37[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "permute_26 (Permute)            (None, 50, 49)       0           conv1d_34[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "permute_28 (Permute)            (None, 50, 49)       0           conv1d_36[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "k_max_pooling_28 (KMaxPooling)  (None, 500)          0           permute_26[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "k_max_pooling_30 (KMaxPooling)  (None, 500)          0           permute_28[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "reshape_26 (Reshape)            (None, 10, 50)       0           k_max_pooling_28[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "reshape_28 (Reshape)            (None, 10, 50)       0           k_max_pooling_30[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "flatten_13 (Flatten)            (None, 500)          0           reshape_26[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "flatten_14 (Flatten)            (None, 500)          0           reshape_28[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_1 (Concatenate)     (None, 1000)         0           flatten_13[0][0]                 \n",
      "                                                                 flatten_14[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "dense_13 (Dense)                (None, 1)            1001        concatenate_1[0][0]              \n",
      "==================================================================================================\n",
      "Total params: 501,201\n",
      "Trainable params: 501,201\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# use a separate name for the model so the builder function can be called again later\n",
    "two_feature_map_dynamic_cnn_model = two_feature_map_dynamic_cnn()\n",
    "two_feature_map_dynamic_cnn_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "callbacks = [ModelCheckpoint(filepath = 'best_model.hdf5', monitor='val_acc', verbose=1, save_best_only = True, mode='max')]\n",
    "history = two_feature_map_dynamic_cnn_model.fit(X_train, y_train, callbacks = callbacks, epochs = 10, validation_split = 0.2, batch_size = 200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "24960/25000 [============================>.] - ETA: 0sTest accuracy:  0.88092\n"
     ]
    }
   ],
   "source": [
    "two_feature_map_dynamic_cnn_best_model = two_feature_map_dynamic_cnn()\n",
    "two_feature_map_dynamic_cnn_best_model.load_weights('best_model.hdf5')\n",
    "two_feature_map_dynamic_cnn_best_model.compile(optimizer = 'Adam', loss = 'binary_crossentropy', metrics = ['accuracy'])\n",
    "results = two_feature_map_dynamic_cnn_best_model.evaluate(X_test, y_test)\n",
    "print('Test accuracy: ', results[1])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
